#include <cstdio>
#include <fstream>

#include <gtest/gtest.h>

#include "c/ddk/graph/model.h"
#include "c/ddk/graph/attr_value.h"
#include "c/ddk/graph/handle_types.h"
#include "c/ddk/graph/op_config.h"
#include "c/ddk/graph/graph.h"
#include "c/ddk/graph/operator.h"
#include "graph/csrc/resource_manager.h"
#include "c/ddk/graph/context.h"
#include "graph/shape.h"
#include "graph/model.h"
#include "framework/graph/utils/op_desc_utils.h"
#include "framework/graph/utils/attr_utils.h"
#include "c/ddk/model_manager/hiai_model_build_options.h"
#include "c/ddk/model_manager/hiai_built_model.h"

// Graph I/O arity shared by the tests: two placeholder ops are registered as graph inputs
// and the single convolution op is registered as the graph output.
namespace {
constexpr uint32_t INPUT_NUM{2};
constexpr uint32_t OUTPUT_NUM{1};
}  // namespace

using namespace ge;
using namespace std;
using namespace hiai;

class UTEST_c_model : public testing::Test {
public:
    void SetUp()
    {
        resMgr = HIAI_IR_ResourceManagerCreate();
    }
 
    void TearDown()
    {
        HIAI_IR_ResourceManagerDestroy(&resMgr);
    }

public:
    ResMgrHandle resMgr = nullptr;
    ModelHandle CreateModel()
    {
        const char* name = "modelName";
        ModelHandle modelHandle = HIAI_IR_CreateModel(resMgr, name);
        return modelHandle;
    }

    GraphHandle CreateGraph()
    {
        const char* name = "graph_defalut";
        GraphHandle graph = HIAI_IR_GraphCreate(resMgr, name);
        return graph;
    }
 
    OpHandle CreateData()
    {
        const char* name = "data1";
        uint32_t shapeNum = 4;
        int64_t inputShape[4] = { 10, 4, 120, 440 };
        OpHandle dataOp = HIAI_IR_CreatePlaceHodler(
            resMgr, name, HIAI_Format::HIAI_FORMAT_NCHW, HIAI_DATATYPE_FLOAT32, inputShape, shapeNum);
        return dataOp;
    }

    OpHandle CreateData2()
    {
        const char* name = "data2";
        uint32_t shapeNum = 4;
        int64_t inputShape[4] = { 10, 4, 120, 440 };
        OpHandle dataOp = HIAI_IR_CreatePlaceHodler(
            resMgr, name, HIAI_Format::HIAI_FORMAT_NCHW, HIAI_DATATYPE_FLOAT32, inputShape, shapeNum);
        return dataOp;
    }
 
    OpHandle CreateConv(OpHandle dataOp, OpHandle dataOp2)
    {
        HIAI_IR_Input input[2]; // 根据输入个数确定数组大小
        input[0].op = dataOp;
        input[0].outIndex = 0;

        input[1].op = dataOp2;
        input[1].outIndex = 0;
 
        int64_t stridesShape[] = {2, 2};
        AttrHandle stridesAttr = HIAI_IR_CreateInt64LstAttr(resMgr, stridesShape, 2);
        HIAI_IR_Params params[1];
        params[0].name = "strides";
        params[0].attrValue = stridesAttr;
 
        HIAI_IR_OpConfig opConfig;
        opConfig.opType = "Convolution";
        opConfig.opName = "conv1";
        opConfig.input = input;
        opConfig.inputNums = 2;
        opConfig.params = params;
        opConfig.paramsNums = 1;
 
        OpHandle convOp = HIAI_IR_CreateOp(resMgr, &opConfig);
        return convOp;
    }
};

// Verifies that a model created through the C API is backed by a ge::Model carrying the given name.
TEST_F(UTEST_c_model, CreateModel)
{
    ModelHandle modelHandle = CreateModel();
    // Fatal: everything below resolves this handle.
    ASSERT_NE(modelHandle, nullptr);

    // Resolve the opaque handle to the C++ object held by the resource manager.
    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr basePtr = resMgrPtr->GetSrcPtr(modelHandle);

    std::shared_ptr<ge::Model> modelPtr = std::dynamic_pointer_cast<ge::Model>(basePtr);
    // dynamic_pointer_cast yields nullptr when basePtr is null or not a ge::Model;
    // fail the test instead of dereferencing it.
    ASSERT_NE(modelPtr, nullptr);
    EXPECT_EQ("modelName", modelPtr->GetName());
}

// Builds a "data1, data2 -> conv1" graph, attaches it to a model, and compiles it into a built model.
TEST_F(UTEST_c_model, CreateBuiltModel)
{
    // Every handle below is a prerequisite for the later calls, so stop at the first failure
    // instead of feeding null handles into the API.
    ModelHandle model = CreateModel();
    ASSERT_NE(model, nullptr);

    GraphHandle graph = CreateGraph();
    ASSERT_NE(graph, nullptr);

    OpHandle dataOp = CreateData();
    ASSERT_NE(dataOp, nullptr);
    OpHandle dataOp2 = CreateData2();
    ASSERT_NE(dataOp2, nullptr);

    OpHandle convOp = CreateConv(dataOp, dataOp2);
    ASSERT_NE(convOp, nullptr);

    ConstOpHandle inputs[INPUT_NUM] = { dataOp, dataOp2 };
    HIAI_Status status = HIAI_IR_SetInputs(resMgr, graph, inputs, INPUT_NUM);
    ASSERT_EQ(HIAI_SUCCESS, status);

    ConstOpHandle outputs[OUTPUT_NUM] = { convOp };
    status = HIAI_IR_SetOutputs(resMgr, graph, outputs, OUTPUT_NUM);
    ASSERT_EQ(HIAI_SUCCESS, status);

    status = HIAI_IR_SetGraph(resMgr, graph, model);
    ASSERT_EQ(HIAI_SUCCESS, status);

    // No build options are supplied (nullptr).
    // NOTE(review): presumably this selects the default options; confirm against the API docs.
    HIAI_MR_ModelBuildOptions* buildOptions = nullptr;
    HIAI_MR_BuiltModel* builtModel = nullptr;

    status = HIAI_IR_CreateBuiltModel(resMgr, model, buildOptions, &builtModel);
    // Non-fatal checks so the destroy calls below always run.
    EXPECT_EQ(HIAI_SUCCESS, status);
    EXPECT_NE(builtModel, nullptr);

    HIAI_MR_ModelBuildOptions_Destroy(&buildOptions);
    HIAI_MR_BuiltModel_Destroy(&builtModel);
}

// Builds a "data1, data2 -> conv1" graph, attaches it to a model, and dumps the model to a file.
TEST_F(UTEST_c_model, DumpModel)
{
    // Prerequisite handles: stop at the first failure rather than dumping a half-built model.
    ModelHandle model = CreateModel();
    ASSERT_NE(model, nullptr);

    GraphHandle graph = CreateGraph();
    ASSERT_NE(graph, nullptr);

    OpHandle dataOp = CreateData();
    ASSERT_NE(dataOp, nullptr);
    OpHandle dataOp2 = CreateData2();
    ASSERT_NE(dataOp2, nullptr);

    OpHandle convOp = CreateConv(dataOp, dataOp2);
    ASSERT_NE(convOp, nullptr);

    ConstOpHandle inputs[INPUT_NUM] = { dataOp, dataOp2 };
    HIAI_Status status = HIAI_IR_SetInputs(resMgr, graph, inputs, INPUT_NUM);
    ASSERT_EQ(HIAI_SUCCESS, status);

    ConstOpHandle outputs[OUTPUT_NUM] = { convOp };
    status = HIAI_IR_SetOutputs(resMgr, graph, outputs, OUTPUT_NUM);
    ASSERT_EQ(HIAI_SUCCESS, status);

    status = HIAI_IR_SetGraph(resMgr, graph, model);
    ASSERT_EQ(HIAI_SUCCESS, status);

    const char* file = "model.txt";
    status = HIAI_IR_DumpModel(resMgr, model, file);
    EXPECT_EQ(HIAI_SUCCESS, status);

    // Don't leave the dump artifact in the working directory; a missing file is not an error.
    static_cast<void>(std::remove(file));
}

// Builds a "data1, data2 -> conv1" graph and verifies it can be attached to a model.
TEST_F(UTEST_c_model, SetGraph)
{
    // Prerequisite handles: abort on the first failure instead of passing nulls to the API.
    ModelHandle model = CreateModel();
    ASSERT_NE(model, nullptr);

    GraphHandle graph = CreateGraph();
    ASSERT_NE(graph, nullptr);

    OpHandle dataOp = CreateData();
    ASSERT_NE(dataOp, nullptr);
    OpHandle dataOp2 = CreateData2();
    ASSERT_NE(dataOp2, nullptr);

    OpHandle convOp = CreateConv(dataOp, dataOp2);
    ASSERT_NE(convOp, nullptr);

    // Graph wiring is a precondition of the call under test.
    ConstOpHandle inputs[INPUT_NUM] = { dataOp, dataOp2 };
    HIAI_Status status = HIAI_IR_SetInputs(resMgr, graph, inputs, INPUT_NUM);
    ASSERT_EQ(HIAI_SUCCESS, status);

    ConstOpHandle outputs[OUTPUT_NUM] = { convOp };
    status = HIAI_IR_SetOutputs(resMgr, graph, outputs, OUTPUT_NUM);
    ASSERT_EQ(HIAI_SUCCESS, status);

    // The behavior this test is about.
    status = HIAI_IR_SetGraph(resMgr, graph, model);
    EXPECT_EQ(HIAI_SUCCESS, status);
}
